import abc
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim


class SDEBase(abc.ABC):
    """Base class for an SDE ``dx = f(x, t) dt + g(x, t) dW`` discretised into ``T`` steps.

    The horizon is split into ``T`` Euler-Maruyama steps of size ``dt = 1 / T``
    (so the discretised process spans unit time). Subclasses supply the drift
    ``f`` and the dispersion ``g``; this class provides forward simulation and
    two reverse-time samplers driven by a score function.
    """

    def __init__(self, T):
        self.T = T
        self.dt = 1 / T

    @abc.abstractmethod
    def drift(self, x_t, t):
        """Return the drift ``f(x_t, t)`` at integer step ``t``."""

    @abc.abstractmethod
    def dispersion(self, x_t, t):
        """Return the dispersion ``g(x_t, t)`` at integer step ``t``."""

    def dw(self, x):
        """Sample a Brownian increment shaped like ``x`` (``N(0, dt)`` per element)."""
        return torch.randn_like(x) * math.sqrt(self.dt)

    def reverse_ode(self, x, t, score):
        """Take one backward Euler step of the probability-flow ODE.

        Uses ``dx = [f - 0.5 * g**2 * score] dt`` and returns ``x - dx``.
        """
        dx = (self.drift(x, t) - 0.5 * self.dispersion(x, t) ** 2 * score) * self.dt
        return x - dx

    def reverse_sde(self, x, t, score):
        """Take one backward Euler-Maruyama step of the reverse-time SDE.

        Uses ``dx = [f - g**2 * score] dt + g dW`` and returns ``x - dx``.
        No fresh noise is injected on the final step (``t == 0``).
        """
        f = self.drift(x, t)
        g = self.dispersion(x, t)
        dx = (f - g ** 2 * score) * self.dt + g * self.dw(x) * (t > 0)
        return x - dx

    def forward_step(self, x, t):
        """Take one forward Euler-Maruyama step: ``x + f dt + g dW``."""
        dx = self.drift(x, t) * self.dt + self.dispersion(x, t) * self.dw(x)
        return x + dx

    def forward(self, x_0):
        """Simulate the forward SDE from ``x_0`` through all ``T`` steps."""
        x_ = x_0
        for t in range(self.T):
            x_ = self.forward_step(x_, t)
        return x_

    def reverse(self, x_t, score, mode, state):
        """Integrate backwards from step ``T - 1`` down to step ``0``.

        Args:
            x_t: sample at the last step; ``x_t.shape[0]`` is the batch size.
            score: callable ``score(x, t_batch, state)``. ``t_batch`` is a long
                tensor of shape ``(batch,)`` filled with the current step.
            mode: ``"sde"`` for the stochastic sampler, or ``"ode"`` for the
                deterministic probability-flow sampler.
            state: passed through unchanged to ``score``.

        Returns:
            The sample after all ``T`` backward steps.

        Raises:
            ValueError: if ``mode`` is neither ``"sde"`` nor ``"ode"``.
        """
        # Validate up front rather than silently skipping every update step.
        if mode == "sde":
            step = self.reverse_sde
        elif mode == "ode":
            step = self.reverse_ode
        else:
            raise ValueError(f"mode must be 'sde' or 'ode', got {mode!r}")

        for t in reversed(range(self.T)):
            t_batch = torch.full((x_t.shape[0],), t, dtype=torch.long)
            score_value = score(x_t, t_batch, state)
            x_t = step(x_t, t, score_value)
        return x_t


def vp_beta_schedule(timesteps, dtype=torch.float32, *, b_min=0.1, b_max=10.0):
    """Return the discrete variance-preserving (VP) beta schedule.

    The linear continuous rate ``beta(s) = b_min + (b_max - b_min) * s`` on
    ``s in [0, 1]`` is integrated over each of the ``timesteps`` equal
    sub-intervals. Step ``t`` (1-based) gives
    ``beta_t = 1 - exp(-b_min / T - 0.5 * (b_max - b_min) * (2t - 1) / T**2)``.

    Args:
        timesteps: number of discrete steps ``T``.
        dtype: dtype of the returned tensor.
        b_min: rate at ``s = 0``.
        b_max: rate at ``s = 1``.

    Returns:
        A 1-D tensor of length ``timesteps`` holding the per-step betas.
    """
    t = np.arange(1, timesteps + 1)
    T = timesteps
    # Exact integral of the linear rate over [(t - 1) / T, t / T].
    alpha = np.exp(-b_min / T - 0.5 * (b_max - b_min) * (2 * t - 1) / T ** 2)
    betas = 1 - alpha
    return torch.tensor(betas, dtype=dtype)

class MySDE(SDEBase):
    """Ornstein-Uhlenbeck SDE ``dx = -theta_t x dt + sqrt(2 * theta_t) dW``.

    Because ``sigma_t**2 = 2 * theta_t``, the marginal at step ``t`` given ``x_0`` is
    ``N(x_0 * exp(-theta_bar_t), 1 - exp(-2 * theta_bar_t))``, where
    ``theta_bar_t = dt * sum_{s <= t} theta_s``.

    The schedule tensors are built on the CPU. The indexing helper moves the
    gathered coefficients onto the device of the data they are combined with.
    """

    def __init__(self, T, schedule):
        super().__init__(T=T)

        if schedule == "vp":
            self.thetas = vp_beta_schedule(T)
        else:
            # Fail clearly instead of an AttributeError on self.thetas below.
            raise ValueError(f"unknown schedule {schedule!r}; expected 'vp'")

        self.sigmas = torch.sqrt(2 * self.thetas)

        thetas_cumsum = torch.cumsum(self.thetas, dim=0)
        self.thetas_bar = thetas_cumsum * self.dt
        self.vars = 1 - torch.exp(-2 * self.thetas_bar)
        self.stds = torch.sqrt(self.vars)

    @staticmethod
    def _gather(buf, t, like):
        """Index schedule ``buf`` at ``t`` (int or tensor) and move the result to ``like``'s device."""
        if torch.is_tensor(t):
            t = t.to(buf.device)
        return buf[t].to(like.device)

    def drift(self, x_t, t):
        """Return the OU drift ``-theta_t * x_t`` for integer step ``t``."""
        return -self.thetas[t] * x_t

    def dispersion(self, x_t, t):
        """Return ``sigma_t = sqrt(2 * theta_t)``. It does not depend on ``x_t``."""
        return self.sigmas[t]

    def compute_score_from_noise(self, noise, t):
        """Convert a standardised-noise prediction into a score, ``-noise / std_t``.

        ``t`` may be an int, a 0-dim tensor, or a per-sample index tensor
        shaped to broadcast against ``noise`` (e.g. ``(batch, 1)``).
        """
        return -noise / self._gather(self.stds, t, noise)

    def generate_random_state(self, a_0):
        """Draw a random step per sample and noise ``a_0`` to that step in closed form.

        Returns:
            ``(a_t, t)``. ``t`` is a long tensor of shape ``(batch, 1)`` on the
            schedule's device.
        """
        noise = torch.randn_like(a_0)
        t = torch.randint(0, self.T, (a_0.shape[0], 1))
        mean_coef = torch.exp(-self._gather(self.thetas_bar, t, a_0))
        a_t = a_0 * mean_coef + self._gather(self.stds, t, a_0) * noise
        return a_t, t

    def ground_truth_score(self, a_t, t, a_0):
        """Return the exact score of the Gaussian marginal ``p(a_t | a_0)`` at step ``t``."""
        mean = a_0 * torch.exp(-self._gather(self.thetas_bar, t, a_0))
        return (mean - a_t) / self._gather(self.vars, t, a_0)

class UNet(nn.Module):
    """Convolutional U-Net with a pooling encoder, an upsampling decoder and skip connections.

    Each encoder stage applies two 3x3 conv + ReLU layers with ``features[i]``
    channels. Every stage except the deepest is followed by 2x2 max-pooling.
    Each decoder stage does three things:
      * upsamples 2x with a transposed conv;
      * concatenates the matching encoder output (skip connection) on the
        channel axis;
      * fuses the result with a 3x3 conv.
    A final 1x1 conv maps to ``out_channels``. The output has the input's
    spatial size.

    Args:
        in_channels: channels of the input map (expects ``(N, C, H, W)``).
        out_channels: channels of the output map.
        features: channel widths of the encoder stages, shallowest first.
    """

    # Conv, ReLU, Conv/ConvTranspose, ReLU per encoder/decoder stage.
    _LAYERS_PER_STAGE = 4

    def __init__(self, in_channels, out_channels, features=(64, 128, 256, 512)):
        super().__init__()
        # Tuple default/copy avoids the shared-mutable-default pitfall.
        features = tuple(features)
        self.features = features

        self.encoder = nn.ModuleList()
        for feature in features:
            self.encoder.append(nn.Conv2d(in_channels, feature, kernel_size=3, stride=1, padding=1))
            self.encoder.append(nn.ReLU(inplace=True))
            self.encoder.append(nn.Conv2d(feature, feature, kernel_size=3, stride=1, padding=1))
            self.encoder.append(nn.ReLU(inplace=True))
            in_channels = feature

        self.decoder = nn.ModuleList()
        for feature in reversed(features[:-1]):
            self.decoder.append(nn.ConvTranspose2d(in_channels, feature, kernel_size=2, stride=2))
            self.decoder.append(nn.ReLU(inplace=True))
            # Upsampled map (feature ch) concatenated with the skip (feature ch).
            self.decoder.append(nn.Conv2d(2 * feature, feature, kernel_size=3, stride=1, padding=1))
            self.decoder.append(nn.ReLU(inplace=True))
            in_channels = feature

        self.final_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Map ``(N, in_channels, H, W)`` to ``(N, out_channels, H, W)``."""
        n = self._LAYERS_PER_STAGE
        n_stages = len(self.features)

        skips = []
        for i in range(n_stages):
            for layer in self.encoder[i * n:(i + 1) * n]:
                x = layer(x)
            if i < n_stages - 1:
                skips.append(x)
                x = F.max_pool2d(x, kernel_size=2)

        for i, skip in enumerate(reversed(skips)):
            up, up_act, fuse, fuse_act = self.decoder[i * n:(i + 1) * n]
            x = up_act(up(x))
            # Odd spatial sizes lose a row/column when pooled; realign to the skip.
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(x, size=skip.shape[-2:])
            x = fuse_act(fuse(torch.cat([skip, x], dim=1)))

        return self.final_conv(x)

class DiffusionSDEPolicy(nn.Module):
    """Diffusion policy that samples actions by integrating a VP-scheduled SDE backwards.

    An MLP predicts the standardised forward-process noise from three inputs:
    the noisy action, the conditioning state and the normalised step ``t / T``.
    ``MySDE.compute_score_from_noise`` turns that prediction into a score.

    Args:
        state_dim: size of the conditioning state vector.
        action_dim: size of the action vector.
        hidden_dim: width of the MLP's hidden layers.
        T: number of diffusion steps.
        max_action: sampled actions are clamped to ``[-max_action, max_action]``.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, T, max_action):
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.T = T
        self.max_action = max_action

        # Actions are (batch, action_dim) vectors, which a Conv2d network cannot
        # consume; an MLP over [a_t, state, t/T] is used instead.
        self.model = nn.Sequential(
            nn.Linear(action_dim + state_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )
        self.sde = MySDE(T, 'vp')

    def _predict_noise(self, a_t, t, state):
        """Predict the standardised noise in ``a_t`` given ``state`` and integer step(s) ``t``."""
        # One scalar time feature per sample, in [0, 1), on a_t's device/dtype.
        t_feat = t.reshape(-1, 1).to(device=a_t.device, dtype=a_t.dtype) / self.T
        return self.model(torch.cat([a_t, state, t_feat], dim=-1))

    def score_fn(self, a_t, t, state):
        """Return the model's score estimate for ``a_t`` at step tensor ``t``."""
        noise = self._predict_noise(a_t, t, state)
        # During sampling every entry of t is the same step, so t[0] suffices.
        return self.sde.compute_score_from_noise(noise, t[0])

    def sample(self, state, mode):
        """Sample one action per row of ``state`` using the ``"sde"`` or ``"ode"`` sampler."""
        noise = torch.randn(state.shape[0], self.action_dim, device=state.device)
        action = self.sde.reverse(noise, self.score_fn, mode=mode, state=state)
        return action.clamp_(-self.max_action, self.max_action)

    def loss(self, state, a_0):
        """Return the score-matching MSE between predicted and exact scores at random steps."""
        a_t, t = self.sde.generate_random_state(a_0)
        score_pred = self.sde.compute_score_from_noise(
            self._predict_noise(a_t, t, state), t
        )
        score_true = self.sde.ground_truth_score(a_t, t, a_0)
        return F.mse_loss(score_pred, score_true)

